{
 "metadata": {
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.3"
  },
  "orig_nbformat": 2,
  "kernelspec": {
   "name": "python38364bitd9422d0e0f224c0ebaa082ef5b357e74",
   "display_name": "Python 3.8.3 64-bit"
  },
  "metadata": {
   "interpreter": {
    "hash": "8c207cf98bc84bffc44ccf9494458b90a53a5c17a03e11155797855f32f2c5f8"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2,
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": [
       "tensor([[-0.2514],\n",
       "        [-0.2729]], grad_fn=<AddmmBackward>)"
      ]
     },
     "metadata": {},
     "execution_count": 2
    }
   ],
   "source": [
    "import torch \n",
    "from torch import nn \n",
    "\n",
    "net = nn.Sequential(nn.Linear(4,8), nn.ReLU(), nn.Linear(8, 1))\n",
    "X = torch.rand(size=(2, 4))\n",
    "net(X)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "OrderedDict([('weight', tensor([[-0.1865, -0.2772, -0.3306, -0.3216, -0.1062,  0.0422, -0.1537,  0.1329]])), ('bias', tensor([-0.2960]))])\n"
     ]
    }
   ],
   "source": [
    "print(net[2].state_dict())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "<class 'torch.nn.parameter.Parameter'>\nParameter containing:\ntensor([-0.2960], requires_grad=True)\ntensor([-0.2960])\n"
     ]
    }
   ],
   "source": [
    "print(type(net[2].bias))\n",
    "print(net[2].bias)\n",
    "print(net[2].bias.data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "metadata": {},
     "execution_count": 5
    }
   ],
   "source": [
    "net[2].weight.grad == None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "('weight', torch.Size([8, 4])) ('bias', torch.Size([8]))\n('0.weight', torch.Size([8, 4])) ('0.bias', torch.Size([8])) ('2.weight', torch.Size([1, 8])) ('2.bias', torch.Size([1]))\n"
     ]
    }
   ],
   "source": [
    "print(*[(name, param.shape) for name, param in net[0].named_parameters()])\n",
    "print(*[(name, param.shape) for name, param in net.named_parameters()])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": [
       "tensor([[0.6020],\n",
       "        [0.6020]], grad_fn=<AddmmBackward>)"
      ]
     },
     "metadata": {},
     "execution_count": 7
    }
   ],
   "source": [
    "def block1():\n",
    "    return nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 4), nn.ReLU())\n",
    "\n",
    "def block2():\n",
    "    net = nn.Sequential()\n",
    "    for i in range(4):\n",
    "        net.add_module(f'block{i}', block1())\n",
    "    return net\n",
    "\n",
    "rgnet = nn.Sequential(block2(), nn.Linear(4, 1))\n",
    "rgnet(X) "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "Sequential(\n  (0): Sequential(\n    (block0): Sequential(\n      (0): Linear(in_features=4, out_features=8, bias=True)\n      (1): ReLU()\n      (2): Linear(in_features=8, out_features=4, bias=True)\n      (3): ReLU()\n    )\n    (block1): Sequential(\n      (0): Linear(in_features=4, out_features=8, bias=True)\n      (1): ReLU()\n      (2): Linear(in_features=8, out_features=4, bias=True)\n      (3): ReLU()\n    )\n    (block2): Sequential(\n      (0): Linear(in_features=4, out_features=8, bias=True)\n      (1): ReLU()\n      (2): Linear(in_features=8, out_features=4, bias=True)\n      (3): ReLU()\n    )\n    (block3): Sequential(\n      (0): Linear(in_features=4, out_features=8, bias=True)\n      (1): ReLU()\n      (2): Linear(in_features=8, out_features=4, bias=True)\n      (3): ReLU()\n    )\n  )\n  (1): Linear(in_features=4, out_features=1, bias=True)\n)\n"
     ]
    }
   ],
   "source": [
    "print(rgnet)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": [
       "tensor([-0.1582,  0.0295, -0.1049, -0.0131, -0.3809,  0.3023, -0.2920,  0.0601])"
      ]
     },
     "metadata": {},
     "execution_count": 9
    }
   ],
   "source": [
    "rgnet[0][1][0].bias.data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": [
       "(tensor([-0.0001,  0.0050,  0.0078,  0.0090]), tensor(0.))"
      ]
     },
     "metadata": {},
     "execution_count": 10
    }
   ],
   "source": [
    "def init_normal(m):\n",
    "    if type(m) == nn.Linear:\n",
    "        nn.init.normal_(m.weight, mean=0, std=0.01)\n",
    "        nn.init.zeros_(m.bias)\n",
    "\n",
    "net.apply(init_normal)\n",
    "net[0].weight.data[0], net[0].bias.data[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": [
       "(tensor([1., 1., 1., 1.]), tensor(0.))"
      ]
     },
     "metadata": {},
     "execution_count": 12
    }
   ],
   "source": [
    "def init_constant(m):\n",
    "    if type(m) == nn.Linear:\n",
    "        nn.init.constant_(m.weight, 1)\n",
    "        nn.init.zeros_(m.bias)\n",
    "\n",
    "net.apply(init_constant)\n",
    "net[0].weight.data[0], net[0].bias.data[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "tensor([ 0.6460, -0.3767, -0.2095,  0.5924])\ntensor([[42., 42., 42., 42., 42., 42., 42., 42.]])\n"
     ]
    }
   ],
   "source": [
    "def xavier(m):\n",
    "    if type(m) == nn.Linear:\n",
    "        nn.init.xavier_uniform_(m.weight)\n",
    "    \n",
    "def init_42(m):\n",
    "    if type(m) == nn.Linear:\n",
    "        nn.init.constant_(m.weight, 42)\n",
    "\n",
    "net[0].apply(xavier)\n",
    "net[2].apply(init_42)\n",
    "print(net[0].weight.data[0])\n",
    "print(net[2].weight.data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "init weight torch.Size([8, 4])\ninit weight torch.Size([1, 8])\n"
     ]
    },
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": [
       "tensor([[ 0.0000, -9.3767, -5.7817,  5.9173],\n",
       "        [ 9.1381, -0.0000, -0.0000,  8.7322]], grad_fn=<SliceBackward>)"
      ]
     },
     "metadata": {},
     "execution_count": 14
    }
   ],
   "source": [
    "def my_init(m):\n",
    "    if type(m) == nn.Linear:\n",
    "        print(\n",
    "            'init',\n",
    "            *[(name, param.shape) for name, param in m.named_parameters()][0]\n",
    "        )\n",
    "        nn.init.uniform_(m.weight, -10, 10)\n",
    "        m.weight.data *= m.weight.data.abs() >= 5\n",
    "\n",
    "net.apply(my_init)\n",
    "net[0].weight[:2]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": [
       "tensor([42.0000, -8.3767, -4.7817,  6.9173])"
      ]
     },
     "metadata": {},
     "execution_count": 15
    }
   ],
   "source": [
    "net[0].weight.data[:] += 1 \n",
    "net[0].weight.data[0, 0] = 42 \n",
    "net[0].weight.data[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "tensor([True, True, True, True, True, True, True, True])\ntensor([True, True, True, True, True, True, True, True])\n"
     ]
    }
   ],
   "source": [
    "shared = nn.Linear(8, 8)\n",
    "net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), shared, nn.ReLU(), shared, nn.ReLU(), nn.Linear(8, 1))\n",
    "net(X)\n",
    "print(net[2].weight.data[0] == net[4].weight.data[0])\n",
    "net[2].weight.data[0, 0] = 100\n",
    "print(net[2].weight.data[0] == net[4].weight.data[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch \n",
    "import torch.nn.functional as F \n",
    "from torch import nn \n",
    "\n",
    "class MyCenteredLayer(nn.Module):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "\n",
    "    def forward(self, X):\n",
    "        return X - X.mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": [
       "tensor([-2., -1.,  0.,  1.,  2.])"
      ]
     },
     "metadata": {},
     "execution_count": 24
    }
   ],
   "source": [
    "layer = MyCenteredLayer()\n",
    "layer(torch.FloatTensor([1, 2, 3, 4, 5]))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": [
       "tensor(9.3132e-10, grad_fn=<MeanBackward0>)"
      ]
     },
     "metadata": {},
     "execution_count": 26
    }
   ],
   "source": [
    "net = nn.Sequential(nn.Linear(8, 128), MyCenteredLayer())\n",
    "Y = net(torch.rand(4, 8))\n",
    "Y.mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [],
   "source": [
    "class MyLinear(nn.Module):\n",
    "    def __init__(self, in_units, units):\n",
    "        super().__init__()\n",
    "        self.weight = nn.Parameter(torch.randn(in_units, units))\n",
    "        self.bias = nn.Parameter(torch.randn(units,))\n",
    "\n",
    "    def forward(self, X):\n",
    "        linear = torch.matmul(X, self.weight.data) + self.bias.data \n",
    "        return F.relu(linear)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": [
       "Parameter containing:\n",
       "tensor([[-0.2872,  1.7291,  1.1134],\n",
       "        [ 0.1566,  0.2317,  1.9818],\n",
       "        [-1.3092, -0.3687, -0.3109],\n",
       "        [-0.2282,  0.0694,  0.7117],\n",
       "        [-0.3479, -1.4556,  0.0203]], requires_grad=True)"
      ]
     },
     "metadata": {},
     "execution_count": 28
    }
   ],
   "source": [
    "dense = MyLinear(5, 3)\n",
    "dense.weight"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": [
       "tensor([[0.0000, 0.0000, 1.0288],\n",
       "        [0.0000, 0.2081, 1.3999]])"
      ]
     },
     "metadata": {},
     "execution_count": 29
    }
   ],
   "source": [
    "dense(torch.rand(2, 5))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": [
       "tensor([[2.7761],\n",
       "        [4.6184]])"
      ]
     },
     "metadata": {},
     "execution_count": 30
    }
   ],
   "source": [
    "net = nn.Sequential(MyLinear(64, 8), MyLinear(8, 1))\n",
    "net(torch.rand(2, 64))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch \n",
    "from torch import nn \n",
    "from torch.nn import functional as F \n",
    "\n",
    "x = torch.arange(4)\n",
    "y = x\n",
    "torch.save(x, 'x-file')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": [
       "tensor([0, 1, 2, 3])"
      ]
     },
     "metadata": {},
     "execution_count": 35
    }
   ],
   "source": [
    "x2 = torch.load('x-file')\n",
    "x2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": [
       "{'x': tensor([0, 1, 2, 3]), 'y': tensor([0, 1, 2, 3])}"
      ]
     },
     "metadata": {},
     "execution_count": 36
    }
   ],
   "source": [
    "mydict = {'x':x, 'y':y}\n",
    "torch.save(mydict, 'mydict')\n",
    "mydict2 = torch.load('mydict')\n",
    "mydict2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [],
   "source": [
    "class MLP(nn.Module):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.hidden = nn.Linear(20, 256)\n",
    "        self.output = nn.Linear(256, 10)\n",
    "    \n",
    "    def forward(self, x):\n",
    "        return self.output(F.relu(self.hidden(x)))\n",
    "    \n",
    "net = MLP()\n",
    "X = torch.randn(size=(2, 20))\n",
    "Y = net(X)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.save(net.state_dict(), 'mlp.params')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": [
       "MLP(\n",
       "  (hidden): Linear(in_features=20, out_features=256, bias=True)\n",
       "  (output): Linear(in_features=256, out_features=10, bias=True)\n",
       ")"
      ]
     },
     "metadata": {},
     "execution_count": 40
    }
   ],
   "source": [
    "clone = MLP()\n",
    "clone.load_state_dict(torch.load('mlp.params'))\n",
    "clone.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": [
       "tensor([[True, True, True, True, True, True, True, True, True, True],\n",
       "        [True, True, True, True, True, True, True, True, True, True]])"
      ]
     },
     "metadata": {},
     "execution_count": 41
    }
   ],
   "source": [
    "Y_clone = clone(X)\n",
    "Y == Y_clone"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ]
}